"""
Plot the results from the model

"""
from numpy import linalg as LA
import json, torch, os, cv2
import numpy as np
import torch_geometric as pyg
from matplotlib import pyplot as plt, patches
from torch_geometric.data import InMemoryDataset, Data
from torch import nn
import subprocess as sp
import shlex
import math
import pandas as pd
import time
import matplotlib.cm as cm


class Encoder(nn.Module):
    """Four-layer MLP with Mish activations between consecutive linear layers.

    Maps ``input_size`` features to ``output_size`` features through three
    hidden layers of width ``hidden_size``.
    """

    def __init__(self, input_size, output_size, hidden_size):
        super().__init__()
        self.output_size = output_size

        widths = [input_size, hidden_size, hidden_size, hidden_size, output_size]
        modules = []
        for depth, (fan_in, fan_out) in enumerate(zip(widths[:-1], widths[1:])):
            if depth:
                modules.append(nn.Mish())
            modules.append(nn.Linear(fan_in, fan_out))
        # One nn.Sequential named `layers` keeps the state_dict keys
        # (layers.0, layers.2, ...) so existing checkpoints still load.
        self.layers = nn.Sequential(*modules)

    def forward(self, x):
        """Apply the MLP to the last dimension of ``x``."""
        return self.layers(x)


class Decoder(nn.Module):
    """Two-layer MLP head mapping ``input_size`` features to one scalar per row."""

    def __init__(self, input_size, hidden_size):
        super().__init__()
        hidden = nn.Linear(input_size, hidden_size)
        activation = nn.Mish()
        head = nn.Linear(hidden_size, 1)
        # Kept as `layers` (indices 0, 1, 2) so checkpoint keys are unchanged.
        self.layers = nn.Sequential(hidden, activation, head)

    def forward(self, x):
        """Return the prediction for each row of ``x``; the last dimension becomes 1."""
        out = self.layers(x)
        return out


class CN(nn.Module):
    """Interaction-network style model: an edge ``Encoder`` followed by a node ``Decoder``.

    ``forward`` takes
      * ``objects``: (n_objects, object_dim) node features,
      * ``sender_relations`` / ``receiver_relations``: (n_objects, n_relations)
        matrices linking each relation to its sender / receiver node,
      * ``relation_info``: (n_relations, relation_dim) edge features,
    and returns one prediction per node, shape (n_objects, 1).

    The constructor defaults reproduce the original hard-coded sizes
    (4 node features, 1 edge feature, 50-dim effects, hidden widths 150/100),
    so ``CN()`` still matches existing checkpoints.
    """

    def __init__(self, object_dim=4, relation_dim=1, effect_dim=50,
                 encoder_hidden=150, decoder_hidden=100):
        super(CN, self).__init__()
        # forward() concatenates no external node inputs, so this must stay 0.
        x_external_dim = 0
        self.encoder_model = Encoder(2 * object_dim + relation_dim, effect_dim, encoder_hidden)
        self.decoder_model = Decoder(object_dim + effect_dim + x_external_dim, decoder_hidden)

    def forward(self, objects, sender_relations, receiver_relations, relation_info):
        # Cast every input once. Previously the uncast `objects` went into the
        # decoder concat, so float64 node features crashed the float32 Linear.
        objects = objects.float()
        sender_relations = sender_relations.float()
        receiver_relations = receiver_relations.float()

        # Gather the receiver / sender node features for every relation.
        senders = torch.matmul(torch.t(sender_relations), objects)
        receivers = torch.matmul(torch.t(receiver_relations), objects)
        m = torch.cat((receivers, senders, relation_info.float()), 1)
        effects = self.encoder_model(m)

        # Accumulate each relation's effect onto its receiving node.
        effect_receivers = torch.matmul(receiver_relations, effects)
        aggregation_result = torch.cat((objects, effect_receivers), 1)
        return self.decoder_model(aggregation_result)

def generate_position_graph(x_list, node_size, total_time_step, folder_path, x_min, x_max,
                            step=10, dt=1e-4):
    """Render one PNG per sampled time step, drawing each floe as a coloured circle.

    Frames are written as ``{folder_path}/{t}.png`` for ``t`` in
    ``range(0, total_time_step, step)``. This is the naming ``generate_video`` reads.

    :param x_list: x positions indexed as ``x_list[node][t]``
    :param node_size: number of floes to draw (the first ``node_size`` rows of ``x_list``)
    :param total_time_step: the total number of time steps available
    :param folder_path: folder the frames are written to (created if missing)
    :param x_min: unused; kept for backward compatibility (the x axis is fixed to [0, 200])
    :param x_max: unused; kept for backward compatibility
    :param step: stride between rendered time steps
    :param dt: time per step, used only in the frame title
    :return: None
    """
    # Loop-invariant work: create the folder once, and fix one colour per floe
    # so each floe keeps its colour across frames.
    os.makedirs(folder_path, exist_ok=True)
    colors = cm.rainbow(np.linspace(0, 1, node_size))

    for t in range(0, total_time_step, step):
        # num=1, clear=True reuses a single figure instead of creating one per frame.
        fig, ax = plt.subplots(num=1, clear=True)
        fig.set_size_inches(100 * 0.8, 2 * 0.8)  # width, height
        ax.set_xlim(0, 200)
        ax.set_ylim(-2, 2)
        plt.axis('off')
        # Thick vertical lines mark x = 0 and x = 200.
        plt.axvline(x=0, color='black', linewidth=20, label='axvline - full height')
        plt.axvline(x=200, color='black', linewidth=20, label='axvline - full height')

        ax.set_title(f"t = {t * dt :.3f}", x=-0.05, y=0.0, fontsize=80)
        for node in range(node_size):
            ax.add_patch(patches.Circle((x_list[node][t], 0), radius=1, color=colors[node]))
        plt.savefig(f'{folder_path}/{t}.png', bbox_inches='tight')

def generate_video(image_folder, num, total_time, step, fps=100, scale=5):
    """Assemble the frames ``{image_folder}/{t}.png`` into an AVI video.

    Frames are read for ``t`` in ``range(0, total_time, step)`` (the naming used by
    ``generate_position_graph``). Each frame is downscaled by ``scale`` in both
    dimensions. The video is written to ``{image_folder}/1D_simulation_{num}.avi``.

    :param image_folder: input image folder; the video is written here too
    :param num: the label for different situations, used in the video file name
    :param total_time: the total time steps
    :param step: jump steps in time
    :param fps: frames per second of the output video
    :param scale: downscale factor applied to width and height
    :return: None
    :raises ValueError: if the range selects no frames
    :raises FileNotFoundError: if a frame cannot be read
    """
    video_name = f'{image_folder}/1D_simulation_{num}.avi'
    image_paths = [os.path.join(image_folder, f'{t}.png') for t in range(0, total_time, step)]
    if not image_paths:
        raise ValueError(f'no frames selected by range(0, {total_time}, {step})')

    def _read(path):
        """Read one frame, failing loudly instead of passing None to OpenCV."""
        frame = cv2.imread(path)
        if frame is None:  # cv2.imread returns None for missing/unreadable files
            raise FileNotFoundError(f'cannot read frame {path}')
        return frame

    height, width = _read(image_paths[0]).shape[:2]
    down_points = (int(width / scale), int(height / scale))  # OpenCV expects (width, height)

    # fourcc=0 as before: no compression codec requested (backend-dependent).
    video = cv2.VideoWriter(video_name, 0, fps, down_points)
    try:
        for path in image_paths:
            video.write(cv2.resize(_read(path), down_points, interpolation=cv2.INTER_LINEAR))
    finally:
        # Release even on failure so the container is finalized and the file handle freed.
        video.release()


def roll_out_data_2steps_v(data_path, split, total_simulation=None, total_time_step=10000, node_size=32):
    """Load ground-truth trajectories from ``{data_path}/{split}_position.dat``.

    The file is memory-mapped as float32, copy-on-write (mode "c"), with shape
    ``(total_simulation, total_time_step, node_size, 2)``. Channel 0 of the last
    axis is returned as "position" and channel 1 as "velocity".

    :param data_path: folder containing the ``.dat`` files
    :param split: dataset split name, e.g. 'train' or 'valid'
    :param total_simulation: number of simulations stored in the file; defaults
        to 1000 for 'train' and 100 for any other split
    :param total_time_step: number of time steps per simulation
    :param node_size: number of nodes per time step
    :return: one dict per simulation with "position" and "velocity" arrays of
        shape (total_time_step, node_size). These are views into the memmap, not copies.
    """
    if total_simulation is None:
        total_simulation = 1000 if split == 'train' else 100

    data_fp = np.memmap(os.path.join(data_path, f"{split}_position.dat"), dtype=np.float32, mode="c",
                        shape=(total_simulation, total_time_step, node_size, 2))

    return [{"position": simulation[:, :, 0], 'velocity': simulation[:, :, 1]} for simulation in data_fp]

def _build_relation_matrices(n_objects):
    """Return ``(sender_relations, receiver_relations)`` for a 1D chain of ``n_objects`` nodes.

    Both are float64 arrays of shape ``(n_objects, 2 * (n_objects - 1))`` with one 1
    per column. Relation ``2k`` goes from node ``k`` to node ``k + 1``; relation
    ``2k + 1`` goes from node ``k + 1`` back to node ``k``.
    """
    n_relations = (n_objects - 1) * 2
    receiver_relations = np.zeros((n_objects, n_relations), dtype=float)
    sender_relations = np.zeros((n_objects, n_relations), dtype=float)
    for i in range(1, n_objects - 1):  # interior nodes first
        receiver_relations[i, 2 * i - 2] = 1.0
        receiver_relations[i, 2 * i + 1] = 1.0
        sender_relations[i, 2 * i] = 1.0
        sender_relations[i, 2 * i - 1] = 1.0
    # left boundary
    receiver_relations[0, 1] = 1.0
    sender_relations[0, 0] = 1.0
    # right boundary
    receiver_relations[n_objects - 1, n_relations - 2] = 1.0
    sender_relations[n_objects - 1, n_relations - 1] = 1.0
    return sender_relations, receiver_relations


def rollout_IN(model, data, total_time, dt=1e-4):
    """Generate a trajectory autoregressively from the first two ground-truth frames.

    Only ``data["position"][:2]`` and ``data["velocity"][:2]`` are read; every later
    frame is predicted. At each step the node features are
    ``[x_{t-2}, x_{t-1}, v_{t-1}, radius=1]``, and relation ``2k`` / ``2k + 1`` carries
    ``+/-(x_{k+1} - x_k)``. The model output is treated as the new velocity. The first
    and last nodes get zero velocity and are pinned at -1 and 201.

    :param model: module called as ``model(x, sender_relations, receiver_relations, edge_attr)``
        returning one value per node (e.g. ``CN``); its parameters determine the device
    :param data: dict with "position" and "velocity" arrays indexed [time, node]
    :param total_time: number of frames in the returned trajectory, including the two seeds
    :param dt: time step used to integrate the predicted velocity
    :return: CPU tensor of shape (max(total_time, 2), num_nodes) with the dtype of the
        input positions
    """
    seed_positions = torch.tensor(np.asarray(data["position"][:2]))
    seed_velocities = torch.tensor(np.asarray(data["velocity"][:2]))
    num_nodes = seed_positions.shape[1]
    device = next(model.parameters()).device

    sender_np, receiver_np = _build_relation_matrices(num_nodes)
    # Follow the model's device rather than hard-coding .cuda(), so CPU-only runs work.
    sender_relations = torch.from_numpy(sender_np).to(device)
    receiver_relations = torch.from_numpy(receiver_np).to(device)

    # Preallocate the trajectory. Growing it with torch.cat copied every previous
    # frame on each step, which is quadratic in total_time.
    traj = torch.empty((max(total_time, 2), num_nodes), dtype=seed_positions.dtype)
    traj[:2] = seed_positions
    velocity = seed_velocities[-1].reshape(num_nodes, 1)
    radius = torch.ones((num_nodes, 1))

    total_runtime = 0
    with torch.no_grad():
        for t in range(2, total_time):
            x_prev = traj[t - 2].reshape(num_nodes, 1)
            x_last = traj[t - 1].reshape(num_nodes, 1)
            node_features = torch.cat((x_prev.float(), x_last.float(), velocity.float(), radius), dim=1)
            # Interleave +gap / -gap to match the relation order of the incidence matrices.
            gap = x_last[1:, 0] - x_last[:-1, 0]
            edge_features = torch.stack((gap, -gap), dim=1).reshape(-1, 1).float()
            node_features = node_features.to(device)
            edge_features = edge_features.to(device)

            temp_start_time = time.time()
            new_velocity = model(node_features, sender_relations, receiver_relations, edge_features).cpu()
            # Boundary nodes do not move.
            new_velocity[0] = 0
            new_velocity[-1] = 0
            new_position = x_last + new_velocity * dt
            new_position[0] = -1
            new_position[-1] = 201
            traj[t] = new_position.reshape(-1)
            velocity = new_velocity
            total_runtime += time.time() - temp_start_time
    print(f"IN MODEL-{total_runtime}- seconds ---")
    return traj


def rmse(predictions, targets):
    """Return the root-mean-square error between ``predictions`` and ``targets``.

    :param predictions: predicted values (array-like supporting ``-``, ``**`` and ``.mean()``)
    :param targets: reference values, broadcast-compatible with ``predictions``
    :return: ``sqrt(mean((predictions - targets) ** 2))``
    """
    squared_error = (predictions - targets) ** 2
    mean_squared_error = squared_error.mean()
    return np.sqrt(mean_squared_error)

def average_RMSE(predict_traj, true_traj, total_time_steps):
    """Return the time-averaged RMSE and the per-time-step RMSE.

    The RMSE at step ``t`` is ``sqrt(mean((predict_traj[t] - true_traj[t]) ** 2))``.
    The average is the sum of these values divided by ``total_time_steps``.

    :param predict_traj: (time, num_nodes) array or CPU tensor; rows past
        ``total_time_steps`` are ignored
    :param true_traj: (time, num_nodes) array; rows past ``total_time_steps`` are ignored
    :param total_time_steps: number of leading time steps to compare
    :return: (average RMSE, float64 array of shape (total_time_steps,) with per-step RMSE)
    :raises IndexError: if either trajectory has fewer than ``total_time_steps`` rows
    """
    predict = np.asarray(predict_traj)
    truth = np.asarray(true_traj)
    if total_time_steps > min(len(predict), len(truth)):
        raise IndexError(f'total_time_steps={total_time_steps} exceeds trajectory length '
                         f'({len(predict)} predicted, {len(truth)} true)')

    squared_error = (predict[:total_time_steps] - truth[:total_time_steps]) ** 2
    if squared_error.ndim > 1:
        # Average over every axis except time, giving one value per step.
        squared_error = squared_error.mean(axis=tuple(range(1, squared_error.ndim)))
    rmse_array = np.sqrt(squared_error).astype(np.float64)

    aver_rmse = rmse_array.mean()
    return aver_rmse, rmse_array
if __name__ == '__main__':
    import sys

    # Optimal checkpoint number reported by compare_CN_30.py. Pass it as the first
    # command-line argument or set it here. A literal placeholder inside the
    # f-string used to make this whole file a SyntaxError.
    OPTIMAL_CHECKPOINT = sys.argv[1] if len(sys.argv) > 1 else None
    if OPTIMAL_CHECKPOINT is None:
        raise SystemExit('Set OPTIMAL_CHECKPOINT (or pass it as the first argument) to your optimal '
                         'checkpoint number from the results of compare_CN_30.py.')

    split_list = ['valid', 'train']
    DATASET_NAME_org = "CN_30"
    DATASET_NAME = f'{DATASET_NAME_org}_infer_plot'
    model_path = f"./models/{DATASET_NAME_org}"
    MODEL_NAME = f'{DATASET_NAME_org}_{OPTIMAL_CHECKPOINT}'
    # Frames to roll out. Values larger than the stored data (10000 frames)
    # extrapolate in time; errors are then measured only over frames with ground truth.
    total_time_steps = 10000
    PLOT_GRAPH = True
    PLOT_VIDEO = True

    # The model does not depend on the split, so build and load it only once.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    simulator = CN().to(device)
    checkpoint = torch.load(f"{model_path}/{MODEL_NAME}.pt", map_location=device)
    simulator.load_state_dict(checkpoint["model"])
    simulator.eval()

    def _save_current_figure(folder, file_name):
        """Save the current matplotlib figure as folder/file_name, creating folder if needed."""
        os.makedirs(folder, exist_ok=True)
        plt.savefig(f'{folder}/{file_name}')

    def _plot_l2(values, label, file_name):
        """Plot one error curve against the time step and save it in the L2 folder."""
        fig, ax = plt.subplots(num=1, clear=True)
        ax.set_ylabel('L2 norm')
        ax.set_xlabel('t')
        ax.set_title('L2 Norm in different time step')
        ax.plot(values, label=label)
        ax.legend(loc='upper right')
        _save_current_figure(f'./img/{DATASET_NAME}/L2', file_name)

    for split in split_list:
        rollout_dataset = roll_out_data_2steps_v('./data_IN', split)  # ground truth

        # Only the initial conditions come from the data; everything else is predicted.
        for i_dataset, rollout_data in enumerate(rollout_dataset):
            start_time = time.time()
            rollout_out = rollout_IN(simulator, rollout_data, total_time_steps)
            print("--- %s seconds ---" % (time.time() - start_time))

            # Drop the first and last node (the pinned boundaries).
            predicted = np.asarray(rollout_out)[:, 1:-1]
            truth = np.asarray(rollout_data["position"])[:, 1:-1]

            # NaN/inf, or a floe past the right wall (x > 200), marks a failed rollout.
            if not np.isfinite(predicted).all() or (predicted > 200).any():
                print(f'the fail case is {i_dataset} in the {split} dataset')
                continue

            num_nodes = predicted.shape[1]
            horizon = min(len(predicted), len(truth))
            error = predicted[:horizon] - truth[:horizon]

            ################ RMSE for position ################
            IN_RMSE, IN_RMSE_array = average_RMSE(predicted, truth, horizon)
            print(f'average RMSE = {IN_RMSE}')
            fig, ax = plt.subplots(num=1, clear=True)
            ax.set_ylabel('RMSE')
            ax.plot(IN_RMSE_array, label='Proposed Model')
            ax.legend(loc='upper right')
            _save_current_figure(f'./img/{DATASET_NAME}/rmse', f'{i_dataset}_rmse_position_{split}.png')

            if PLOT_VIDEO:
                video_root = f'./img/{DATASET_NAME}/video/{i_dataset}/{split}'
                for name, positions in (('predicted_x', predicted), ('ground_truth_x', truth[:horizon])):
                    image_folder = f'{video_root}/{name}'
                    os.makedirs(image_folder, exist_ok=True)
                    # generate_position_graph indexes positions as [node][t], hence the transpose.
                    x_per_node = positions.T
                    n_frames = positions.shape[0]
                    generate_position_graph(x_per_node, num_nodes, n_frames, image_folder,
                                            x_per_node.min(), x_per_node.max())
                    generate_video(image_folder, i_dataset, n_frames, 10)

            if PLOT_GRAPH:
                ################ L2 norm: all floes, then the first two ################
                _plot_l2(LA.norm(error, axis=1), 'all floes', f'{i_dataset}_all_{split}.png')
                _plot_l2(np.abs(error[:, 0]), 'X1', f'{i_dataset}_x1_{split}.png')
                _plot_l2(np.abs(error[:, 1]), 'X2', f'{i_dataset}_x2_{split}.png')

                ################ relative errors ################
                relative_error = np.abs(error) / truth[:horizon]
                relative_folder = f'./img/{DATASET_NAME}/relative_errors'

                fig, ax = plt.subplots(num=1, clear=True)
                ax.set_ylabel('x')
                ax.set_title('(x_predict - x_truth)/ x_truth in different time step')
                for node in range(num_nodes):
                    ax.plot(relative_error[:, node], label=f'X{node}')
                ax.legend(loc='upper right')
                _save_current_figure(relative_folder, f'{i_dataset}_all_{split}.png')

                fig, ax = plt.subplots(num=1, clear=True)
                ax.set_ylabel('x')
                ax.set_title('(x_predict - x_truth)/ x_truth in different time step')
                ax.plot(relative_error[:, 0], label='X1')
                ax.legend(loc='upper right')
                _save_current_figure(relative_folder, f'{i_dataset}_x1_{split}.png')



